import copy

from knowledge_tracing.network.transformer_layer import *
from knowledge_tracing.args import ARGS

from knowledge_tracing.network.util_network import get_constraint_losses
import torch
from datasets.dataset_parser import Constants


class SAINT(nn.Module):
    '''SAINT encoder-decoder knowledge-tracing model (arguments documented on __init__).'''

    def __init__(self, device, encoder_features, decoder_features, N, d_model,
                 d_ff, h, dropout):
        '''
        SAINT model with (variations of) consistency loss
        Args:
            device: device
            encoder_features: list of tuples (feature, embed_dim) for encoder.
            decoder_features: list of tuples (feature, embed_dim) for decoder.
            N: int, Number of encoder/decoder layers in encoder/decoder.
            d_model: int, size of encoder/decoder layers.
            d_ff: int, hidden size of feedforward layers.
            h: int, Number of heads in multi-headed attention.
            dropout: Dropout rate for feedforward/encoder/decoder layers.
        Raises:
            ValueError: if ARGS.embed_sum is set and the encoder (or decoder)
                feature list is empty or its embedding dimensions differ;
                summed embeddings must all share one width.
        '''
        super().__init__()

        self.device = device
        self.encoder_features = encoder_features
        self.decoder_features = decoder_features

        # embedding
        # NOTE: construction order below is kept fixed on purpose: the xavier
        # initialisation loop consumes the RNG in parameter order.
        self.encoder_embedding_layers = torch.nn.ModuleDict({
            feature.name: feature.embed_layer(dim)
            for feature, dim in encoder_features
        })
        self.decoder_embedding_layers = torch.nn.ModuleDict({
            feature.name: feature.embed_layer(dim)
            for feature, dim in decoder_features
        })

        # set model
        c = copy.deepcopy
        attn = MultiHeadedAttention(h, d_model)
        ff = PositionwiseFeedForward(d_model, d_ff, d_model, dropout)

        if ARGS.embed_sum:
            # Embeddings are summed, so the input width is the shared width.
            enc_input_dim = self._shared_embed_dim(encoder_features, 'encoder')
            dec_input_dim = self._shared_embed_dim(decoder_features, 'decoder')
        else:
            # Embeddings are concatenated, so the input width is their total.
            enc_input_dim = sum(dim for _, dim in encoder_features)
            dec_input_dim = sum(dim for _, dim in decoder_features)

        self.model = EncoderDecoder(
            Encoder(EncoderLayer(d_model, c(attn), c(ff), dropout), N,
                    enc_input_dim, d_model),
            Decoder(DecoderLayer(d_model, c(attn), c(attn), c(ff), dropout), N,
                    dec_input_dim, d_model))

        # last_output
        self.generator = Generator(d_model)

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        self.augmentations = ARGS.augmentations  # deletion, insertion, replacement

    @staticmethod
    def _shared_embed_dim(features, side):
        '''Return the single embedding dim shared by `features`; raise ValueError otherwise.'''
        dims = {dim for _, dim in features}
        if len(dims) != 1:
            raise ValueError(
                f'embed_sum requires all {side} feature embeddings to share one '
                f'dimension, got {sorted(dims)}')
        return dims.pop()

    @staticmethod
    def padding_mask(target, padding_value):
        '''Boolean mask of non-padding positions with a new axis inserted at -2.'''
        return (target != padding_value).unsqueeze(-2)

    def _causal_padding_mask(self, target, seq_len):
        '''Combine the padding mask of `target` (padding value 0) with a causal mask of length `seq_len`.'''
        mask = self.padding_mask(target, 0)
        # Assuming target is [batch_size, seq_length], the padding mask is
        # [batch_size, 1, seq_length] and broadcasts against the
        # [1, seq_length, seq_length] causal mask to [batch_size, seq_length, seq_length].
        return mask * create_subsequent_mask(seq_len).type_as(mask.data)

    @staticmethod
    def _combine_embeds(embeds):
        '''Sum (ARGS.embed_sum) or concatenate along the last axis a list of embedding tensors.'''
        if ARGS.embed_sum:
            # Equal-width embeddings (validated in __init__) are summed element-wise.
            return torch.stack(embeds, dim=0).sum(dim=0)
        return torch.cat(embeds, dim=-1)

    def forward(self, data):
        '''
        Args:
            data: A dictionary of dictionary of tensors. keys ('ori', 'rep', 'ins', 'del')
            represents whether the data is an original or augmented version.
        Returns:
            (last_output, aug_losses): last_output maps every processed key of
            `data` to the generator output. Only 'ori' is processed in eval
            mode. aug_losses is the result of get_constraint_losses in training
            mode, or None when that result is empty (always None in eval mode).
        '''
        last_output = {}
        aug_losses = {}

        for aug in data:
            # Augmented versions are only run while training.
            if aug != 'ori' and not self.training:
                continue
            batch = data[aug]

            # preprocessing
            enc_inputs = {
                feature.name: feature.preprocess_enc(batch[feature.name])
                for feature, _ in self.encoder_features
            }
            dec_inputs = {
                feature.name: feature.preprocess_dec(batch[feature.name])
                for feature, _ in self.decoder_features
            }

            # encoder & decoder embedding, then sum or concatenate per feature
            enc_embed = self._combine_embeds([
                self.encoder_embedding_layers[name](x)
                for name, x in enc_inputs.items()
            ])
            dec_embed = self._combine_embeds([
                self.decoder_embedding_layers[name](x)
                for name, x in dec_inputs.items()
            ])

            # padding + causal masks
            if 'item_idx' in ARGS.enc_feature_names:
                src_mask = self._causal_padding_mask(
                    batch['item_idx'], batch['item_idx'].size(-1))
            else:
                # Without item ids, a step counts as padding when all its tags are 0.
                src_mask = self._causal_padding_mask(
                    batch['tags'].sum(-1), batch['tags'].size(1))
            tgt_mask = self._causal_padding_mask(
                batch['is_correct'], batch['is_correct'].size(-1))

            # forward
            dec_output, _ = self.model(enc_embed, dec_embed, src_mask, tgt_mask)
            last_output[aug] = self.generator(dec_output)

        # constraint loss
        if self.training:
            aug_losses = get_constraint_losses(data, last_output)

        if len(aug_losses) == 0:
            return last_output, None
        return last_output, aug_losses


class SAKTLayer(nn.Module):
    """
    One SAKT attention block.

    Multi-head attention (query attends over key, which is also the value),
    then a position-wise feed-forward network. Each sublayer is followed by a
    residual add and a LayerNorm with eps=1e-6.
    """
    def __init__(self, hidden_dim, num_head, dropout):
        super().__init__()
        # Creation order is kept fixed: it determines parameter order, which
        # the owning model's initialisation loop walks.
        self._self_attn = MultiHeadedAttention(num_head, hidden_dim, dropout)
        self._ffn = PositionwiseFeedForward(hidden_dim, hidden_dim, hidden_dim, dropout)
        self._layernorms = clones(nn.LayerNorm(hidden_dim, eps=1e-6), 2)

    def forward(self, query, key, mask=None):
        """
        Args:
            query: question embeddings.
            key: interaction embeddings + positional embeddings (also the value).
            mask: optional mask passed through to the attention module.
        Returns:
            Feed-forward sublayer output after its residual add + LayerNorm.
        """
        attn_norm = self._layernorms[0]
        ffn_norm = self._layernorms[1]

        # NOTE(review): the attention residual adds `key`, not `query`; confirm
        # this is intended (some SAKT formulations use the query here).
        attended = self._self_attn(query=query, key=key, value=key, mask=mask)
        hidden = attn_norm(key + attended)

        return ffn_norm(hidden + self._ffn(hidden))


class SAKT(nn.Module):
    def __init__(self, device, encoder_features, d_model, h, dropout):
        '''
        SAKT model with (variations of) consistency loss
        Args:
            device: device
            encoder_features: list of tuples (feature, embed_dim) for encoder.
                forward() looks up the 'item_idx', 'interaction_idx' and
                'position' embeddings, so all three features must be present.
                Presumably each embed_dim must equal d_model, because the
                attention block's LayerNorm and the generator are sized by
                d_model. TODO confirm.
            d_model: int, size of the attention block. (d_ff == d_model in SAKT)
            h: int, Number of heads in multi-headed attention.
            dropout: Dropout rate for the attention and feed-forward sublayers.
        '''
        super().__init__()

        self.device = device
        self.encoder_features = encoder_features  # position, interaction_idx, item_idx

        # embedding
        # NOTE: construction order below is kept fixed on purpose: the xavier
        # initialisation loop consumes the RNG in parameter order.
        self.encoder_embedding_layers = torch.nn.ModuleDict({
            feature.name: feature.embed_layer(dim)
            for feature, dim in encoder_features
        })
        self._question_num = Constants(ARGS.dataset_name, ARGS.data_root).NUM_ITEMS

        # model
        self.model = SAKTLayer(d_model, h, dropout)

        # last_output: one sigmoid score per question id (ids 0..NUM_ITEMS)
        self.generator = nn.Sequential(
            nn.Linear(d_model, self._question_num + 1),
            nn.Sigmoid()
        )

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

        self.augmentations = ARGS.augmentations  # deletion, insertion, replacement

    @staticmethod
    def padding_mask(target, padding_value):
        '''Boolean mask of non-padding positions with a new axis inserted at -2.'''
        return (target != padding_value).unsqueeze(-2)

    def _shift_tensor(self, x):
        """
        Shift `x` right by one step along dim 1 and fill the first step with 0.

        The result has the same shape, dtype and device as `x`. An empty
        sequence stays empty.
        """
        shifted_x = torch.zeros_like(x)
        shifted_x[:, 1:] = x[:, :-1]
        return shifted_x

    def forward(self, data):
        '''
        Args:
            data: A dictionary of dictionary of tensors. keys ('ori', 'rep', 'ins', 'del')
            represents whether the data is an original or augmented version.
        Returns:
            (last_output, aug_losses): last_output maps every processed key of
            `data` to the generator's sigmoid scores, gathered at each step's
            item_idx. Only 'ori' is processed in eval mode. aug_losses is the
            result of get_constraint_losses in training mode, or None when that
            result is empty (always None in eval mode).
        '''
        last_output = {}
        aug_losses = {}

        for aug in data:
            # Augmented versions are only run while training.
            if aug != 'ori' and not self.training:
                continue
            batch = data[aug]

            # preprocessing
            enc_inputs = {
                feature.name: feature.preprocess_enc(batch[feature.name])
                for feature, _ in self.encoder_features
            }
            # Shift interactions right by one. The causal mask below includes
            # the diagonal, so step t then only sees interactions before t.
            enc_inputs['interaction_idx'] = self._shift_tensor(enc_inputs['interaction_idx'])

            # encoder embedding for original & augmented data
            enc_embeds = {
                name: self.encoder_embedding_layers[name](x)
                for name, x in enc_inputs.items()
            }

            # padding + causal mask, both derived from item_idx
            item_idx = batch['item_idx']
            mask = self.padding_mask(item_idx, 0)
            mask = mask * create_subsequent_mask(item_idx.size(-1)).type_as(mask.data)

            # forward
            query = enc_embeds['item_idx']
            key = enc_embeds['interaction_idx'] + enc_embeds['position']
            enc_output = self.model(query, key, mask)
            scores = self.generator(enc_output)
            # Keep only the score of the question actually asked at each step.
            last_output[aug] = scores.gather(-1, item_idx.unsqueeze(-1))

        # constraint loss
        if self.training:
            aug_losses = get_constraint_losses(data, last_output)

        if len(aug_losses) == 0:
            return last_output, None
        return last_output, aug_losses


def create_subsequent_mask(size, device=None):
    """
    Mask out subsequent positions.

    Args:
        size: sequence length.
        device: device to build the mask on; defaults to ARGS.device.
    Returns:
        Bool tensor of shape [1, size, size] whose entry [0, i, j] is True iff
        j <= i (lower triangle, diagonal included).
    """
    if device is None:
        device = ARGS.device
    # Build directly on the target device instead of on CPU followed by a copy.
    positions = torch.arange(size, device=device)
    mask = positions.unsqueeze(0) <= positions.unsqueeze(-1)
    return mask.unsqueeze(0)
